from PIL import Image, ImageDraw
import json
import os
import re
import pdb
from tqdm import tqdm
import random
import argparse
import jsonlines
import ast

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np


# Placeholder locations -- edit before running.
# `imgs_dir` is joined with each step's `img_filename` to locate screenshots;
# `anno_dir` must contain `miniwob_data_<version>.json`.
imgs_dir =  "Your image path"
anno_dir = "Your annotation path"


# System prompt stored verbatim as the "problem" field of every record produced
# by data_transform (despite the MIND2WEB name, it is used for MiniWob here).
# NOTE(review): this string is never passed through str.format in this file, so
# the doubled braces `{{ ... }}` reach the model literally; confirm whether a
# downstream consumer formats it, otherwise they should be single braces.
# The tag names (including the "Decesion" spelling) define the expected output
# format; presumably they are parsed downstream, so do not correct them in
# isolation.
_MIND2WEB_SYSTEM_ADD_DESCRIPTION_THINKING = """You are an assistant trained to navigate the web. 
Given a task instruction, a screenshot, and a last history action summary, output the think and ext action and wait for the next observation. 
The think must strictly follow these reasoning steps:
(1) Progress Estimation: Interface Comprehension and Progress Estimation
(2) Decesion Reasoning: Strategy Formulation
(3) History Summary: Update the history action summary according to the last history action summary and the action you executed.

## Action Space
1. `CLICK`: Click on an element, value is the element to click and the position [x,y] is required.
2. `TYPE`: Type a string into an element, value is the string to type and the position [x,y] is required.

## Output Format
<Progress Estimation>
...
</Progress Estimation>
<Decesion Reasoning>
...
</Decesion Reasoning>
<answer>
{{'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}}
</answer>
<History Summary>
...
</History Summary>

If value or position is not applicable, set it as `None`.
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
"""

def normalize_bbox(bbox, size):
    """Scale an absolute ``[x1, y1, x2, y2]`` box into image-relative coordinates.

    Args:
        bbox: Four pixel coordinates ``(x1, y1, x2, y2)``.
        size: ``(width, height)`` of the image the box lives in.

    Returns:
        ``[x1 / width, y1 / height, x2 / width, y2 / height]`` as floats.
    """
    x_min, y_min, x_max, y_max = bbox
    img_w, img_h = size
    # x coordinates are divided by the width, y coordinates by the height.
    divisors = (img_w, img_h, img_w, img_h)
    return [coord / dim for coord, dim in zip((x_min, y_min, x_max, y_max), divisors)]


def get_bbox(action, image_size):
    """Convert an ``{x, y, width, height}`` box into a normalized ``[x1, y1, x2, y2]``.

    Args:
        action: Mapping whose ``"bbox"`` entry has ``x``, ``y``, ``width`` and
            ``height`` keys in pixels.
        image_size: Sequence whose first two items are image width and height.

    Returns:
        Corner coordinates divided by the image size, each rounded to 3 decimals.
    """
    box = action["bbox"]
    img_w = image_size[0]
    img_h = image_size[1]

    left = box["x"]
    top = box["y"]
    right = left + box["width"]
    bottom = top + box["height"]

    scaled = (left / img_w, top / img_h, right / img_w, bottom / img_h)
    return [round(value, 3) for value in scaled]

def get_value(step_repr):
    """Extract the text between ``]`` and ``->`` in a step representation string.

    For example ``"[button]  Submit -> CLICK"`` yields ``"Submit"``. Surrounding
    whitespace is required on both sides and excluded from the result; the
    shortest such match is used.

    Returns:
        The captured text, or ``None`` when the pattern does not occur.
    """
    found = re.search(r'\]\s+(.*?)\s+->', step_repr)
    return None if found is None else found.group(1)

def get_answer(step):
    """Build the answer dict for one converted step.

    Args:
        step: Mapping with ``action_type`` and ``action_meta``. For ``'click'``
            the meta is a ``[x1, y1, x2, y2]`` box; for anything else it is used
            as the action's value as-is.

    Returns:
        ``{'action': ACTION_TYPE_UPPER, 'value': ..., 'position': ...}`` in that
        key order (callers serialize it with ``str``). Clicks carry the box
        center rounded to 2 decimals as ``position`` and ``value=None``; other
        actions carry the meta as ``value`` and ``position=None``.
    """
    kind = step['action_type']
    meta = step['action_meta']

    if kind == 'click':
        center_x = round((meta[0] + meta[2]) / 2, 2)
        center_y = round((meta[1] + meta[3]) / 2, 2)
        return {'action': kind.upper(), 'value': None, 'position': [center_x, center_y]}

    return {'action': kind.upper(), 'value': meta, 'position': None}

def _convertible_steps(episode):
    """Return ``(step, img_path)`` pairs for the steps of *episode* that can be converted.

    Steps whose screenshot does not exist under ``imgs_dir`` are dropped
    silently. Steps whose ``action_type`` is neither ``'click'`` nor ``'type'``
    are printed and dropped, since no ``action_meta`` can be built for them.
    """
    kept = []
    for step in episode:
        img_path = os.path.join(imgs_dir, step['img_filename'])
        if not os.path.exists(img_path):
            continue
        if step['action_type'] not in ('click', 'type'):
            print(step)
            continue
        kept.append((step, img_path))
    return kept


def _format_history(previous_actions):
    """Render earlier answer strings as ``"Step{i}, previous action: {answer}. "`` segments."""
    # Each entry is str(dict) and therefore already ends with "}", so the
    # former `action[:-1] + "}"` reconstruction produced the same text.
    return "".join(
        f"Step{i}, previous action: {action}. " for i, action in enumerate(previous_actions)
    )


def data_transform(version='train', mini=False):
    """Convert the MiniWob annotation file for *version* into training records.

    Reads ``miniwob_data_{version}.json`` from ``anno_dir``: a dict keyed by
    scenario whose values are lists of episodes, each episode a list of step
    dicts (``img_filename``, ``goal``, ``action_type`` and either ``bbox`` or
    ``typed_text``). One record is emitted per convertible step, as defined by
    ``_convertible_steps``.

    Args:
        version: Split name used in the annotation filename and the record ids.
        mini: Unused; kept for interface compatibility.

    Returns:
        A list of dicts with keys ``id``, ``step_id``, ``image``, ``problem``,
        ``solution``, ``task``, ``history``, ``bbox_ref``, ``next_id``,
        ``is_last`` and ``is_first``. ``next_id`` is the ``step_id`` of the next
        record in the same episode, or the record's own ``step_id`` for the last one.
    """
    anno_path = os.path.join(anno_dir, f"miniwob_data_{version}.json")
    with open(anno_path, 'r', encoding='utf-8') as f:
        miniwob = json.load(f)

    total_step = []
    step_i = 0  # running record index across all scenarios and episodes
    for scenario_data in miniwob.values():
        for episode in tqdm(scenario_data):
            steps = _convertible_steps(episode)
            previous_actions = []
            for idx, (step, img_path) in enumerate(steps):
                # Measured over the *kept* steps, so a skipped trailing step
                # cannot leave is_last False / next_id pointing into the next episode.
                is_last = idx == len(steps) - 1

                if step['action_type'] == 'click':
                    # The screenshot is opened only to read its size.
                    with Image.open(img_path) as image:
                        action_meta = normalize_bbox(step['bbox'], image.size)
                else:  # 'type' -- other action types were filtered out above
                    action_meta = step['typed_text']

                cur_answer = str(get_answer({
                    "img_url": step['img_filename'],
                    "action_type": step['action_type'],
                    "action_meta": action_meta,
                }))

                total_step.append({
                    "id": f"MiniWob_{version}_{step_i}",
                    "step_id": step_i,
                    "image": img_path,
                    "problem": _MIND2WEB_SYSTEM_ADD_DESCRIPTION_THINKING,
                    "solution": cur_answer,
                    "task": step['goal'],
                    "history": _format_history(previous_actions),
                    "bbox_ref": action_meta,
                    "next_id": step_i if is_last else step_i + 1,
                    "is_last": is_last,
                    "is_first": idx == 0,
                })
                previous_actions.append(cur_answer)
                step_i += 1

    return total_step

if __name__ == "__main__":
    # Accumulate every requested split. Previously each iteration overwrote the
    # result, so only the last split in the list would have been written.
    train_step = []
    for version in ['train']:
        train_step.extend(data_transform(version=version))

    save_url = "Your save path"  # placeholder -- set before running
    with jsonlines.open(save_url, mode="w") as writer:
        writer.write_all(train_step)

    # test_full = []
    # for version in ['test_task', 'test_domain', 'test_website']:
    #     test_full.extend(data_transform(version=version))
    # save_url = "Your save path"
    # with jsonlines.open(save_url, mode="w") as writer:
    #     writer.write_all(test_full)